#include <errno.h>
#include <seminix/idtable.h>
#include <seminix/slab.h>
#include <seminix/cache.h>

/*
 * alloc_bits - allocate a zeroed bitmap able to hold @nr bits.
 *
 * The bitmap is typed as unsigned long and handed to the word-based bitops
 * (find_next_zero_bit, set_bit, clear_bit), so its size is rounded up to a
 * whole number of unsigned longs; it is never smaller than one cache line.
 *
 * Returns the bitmap, or NULL if the allocation fails.
 */
static inline unsigned long *alloc_bits(int nr)
{
    const int bits_per_long = BITS_PER_BYTE * (int)sizeof(unsigned long);
    /* Round up without computing nr + bits_per_long - 1, which may overflow. */
    int longs = nr / bits_per_long + (nr % bits_per_long != 0);
    int bytes = longs * (int)sizeof(unsigned long);

    return kzalloc(max(bytes, cache_line_size()), GFP_KERNEL);
}

/*
 * __id_table_init - set up @it to track @nr ids, none of them allocated.
 * @it:    table to initialise (not validated here)
 * @nr:    number of ids the bitmap is sized for; stored as it->max
 * @first: initial search hint; values outside [0, nr] fall back to 0
 *
 * Returns 0 on success. Returns -ENOMEM if the bitmap cannot be allocated;
 * in that case it->full_bits is NULL and the other fields are untouched.
 */
static int __id_table_init(struct idtable *it, int nr, int first)
{
    unsigned long *bits = alloc_bits(nr);

    it->full_bits = bits;
    if (bits == NULL)
        return -ENOMEM;

    it->used = 0;
    it->max = nr;
    it->first = (first >= 0 && first <= nr) ? first : 0;
    return 0;
}

/*
 * id_table_init - initialise a caller-provided id table.
 * @it:    table to initialise
 * @nr:    initial number of ids; must be positive
 * @first: initial search hint (see __id_table_init)
 *
 * Returns 0 on success, -EINVAL if @it is NULL or @nr is not positive,
 * or -ENOMEM if the bitmap cannot be allocated.
 */
int id_table_init(struct idtable *it, int nr, int first)
{
    int ret = -EINVAL;

    if (it != NULL && nr > 0)
        ret = __id_table_init(it, nr, first);

    return ret;
}

/*
 * id_table_create - allocate and initialise a new id table.
 * @nr:    initial number of ids; must be positive
 * @first: initial search hint (see __id_table_init)
 *
 * Returns the new table, or NULL if @nr is not positive or memory
 * allocation fails. Release with id_table_destroy(it, true).
 */
struct idtable *id_table_create(int nr, int first)
{
    struct idtable *it;

    /*
     * Same check as id_table_init(): a table with max <= 0 can never grow,
     * because doubling max keeps it <= 0, so an expanding allocation would
     * retry forever.
     */
    if (nr <= 0)
        return NULL;

    it = kmalloc(sizeof(*it), GFP_KERNEL);
    if (!it)
        return NULL;

    if (__id_table_init(it, nr, first)) {
        kfree(it);
        return NULL;
    }

    return it;
}

/*
 * id_table_destroy - release the bitmap of @it and optionally @it itself.
 * @it:         table to tear down; NULL is ignored
 * @free_table: true if @it came from id_table_create(), false if it is
 *              caller-owned storage set up with id_table_init()
 */
void id_table_destroy(struct idtable *it, bool free_table)
{
    if (!it)
        return;

    kfree(it->full_bits);
    if (free_table) {
        kfree(it);
        return;
    }

    /*
     * The struct outlives this call. Clear the pointer so a repeated
     * destroy is harmless (kfree(NULL)) instead of a double free.
     */
    it->full_bits = NULL;
}

/*
 * copy_idtable - copy the allocation state of an @old_max-bit bitmap.
 * @dst:     zeroed bitmap at least as large as @src
 * @src:     bitmap currently tracking @old_max ids
 * @old_max: number of ids tracked by @src
 *
 * Only bits below @old_max are ever set in the table, so copying the bytes
 * that cover them is enough; every other bit of @dst already reads as free.
 */
static void copy_idtable(unsigned long *dst, const unsigned long *src, int old_max)
{
    size_t nbytes;

    if (old_max <= 0)
        return;

    /* Round up: a trailing partial byte still holds live allocation bits. */
    nbytes = (size_t)(old_max / BITS_PER_BYTE) + (old_max % BITS_PER_BYTE != 0);
    memcpy(dst, src, nbytes);
}

/*
 * __id_table_expand - double the capacity of @it, keeping allocated ids.
 *
 * Returns 0 on success. Returns -1 if the table cannot grow (non-positive
 * max, doubling would overflow int, or allocation failed); @it is unchanged
 * on failure.
 */
static int __id_table_expand(struct idtable *it)
{
    /* Largest int that can be doubled without signed overflow. */
    const int max_doublable = (int)(~0U >> 2);
    unsigned long *bits;

    /* A non-positive max never grows: the caller would retry forever. */
    if (it->max <= 0 || it->max > max_doublable)
        return -1;

    bits = alloc_bits(it->max * 2);
    if (!bits)
        return -1;

    copy_idtable(bits, it->full_bits, it->max);
    kfree(it->full_bits);
    it->full_bits = bits;
    it->max *= 2;
    return 0;
}

/*
 * __id_table_get_free - allocate the lowest free id that is >= @start.
 * @it:     table to allocate from; NULL fails
 * @start:  lowest acceptable id; raised to it->first if below it
 * @expand: when true, double the table until a free id >= @start exists;
 *          when false, fail if @start is out of range or no id is free
 *
 * Returns the allocated id, or -1 on failure.
 */
static int __id_table_get_free(struct idtable *it, int start, bool expand)
{
    int from, id;

    if (!it)
        return -1;
    if (start >= it->max && !expand)
        return -1;

    for (;;) {
        from = start < it->first ? it->first : start;

        id = find_next_zero_bit(it->full_bits, it->max, from);
        if (id > it->max)   /* defensive clamp of the "none found" result */
            id = it->max;
        if (id < it->max)
            break;

        if (!expand || __id_table_expand(it))
            return -1;
    }

    set_bit(id, it->full_bits);
    /*
     * Advance the hint only when this search began at it->first. In that
     * case every id in [it->first, id] is now taken. A search from a higher
     * @start says nothing about the ids below it, and moving the hint past
     * them would hide those free ids from later callers.
     */
    if (from == it->first)
        it->first = id + 1;
    it->used++;
    return id;
}

/*
 * id_table_get_free - allocate a free id no lower than @start, growing the
 * table when none is available.
 *
 * Returns the allocated id, or -1 on failure.
 */
int id_table_get_free(struct idtable *it, int start)
{
    const bool may_grow = true;

    return __id_table_get_free(it, start, may_grow);
}

/*
 * id_talbe_get_free_noexpand - allocate a free id no lower than @start
 * without ever growing the table.
 *
 * Returns the allocated id, or -1 if @start is beyond the table or no
 * suitable id is free.
 *
 * NOTE(review): "talbe" is a typo, but it is the exported symbol name
 * (presumably declared in <seminix/idtable.h>). It is kept so callers
 * keep linking.
 */
int id_talbe_get_free_noexpand(struct idtable *it, int start)
{
    const bool may_grow = false;

    return __id_table_get_free(it, start, may_grow);
}

/*
 * id_table_put - release an id previously handed out by this table.
 * @it: table the id belongs to; NULL is ignored
 * @nr: id to release; ids outside [0, it->max) or not currently
 *      allocated are ignored
 *
 * Lowers the search hint when needed so a later allocation can return
 * the id again.
 */
void id_table_put(struct idtable *it, int nr)
{
    /* A negative id would index before the start of the bitmap. */
    if (!it || nr < 0 || nr >= it->max)
        return;

    /* A double put or a put of a never-allocated id must not corrupt used/first. */
    if (!test_bit(nr, it->full_bits))
        return;

    clear_bit(nr, it->full_bits);
    if (nr < it->first)
        it->first = nr;
    it->used--;
}
